# Setup ----
# Run this script in a fresh R session (e.g. via Rscript or "Restart R")
# instead of wiping the global environment with rm(list = ls()), which does
# not detach packages or reset options anyway.
library(cmdstanr)
library(bayesplot)
library(posterior)
library(mcmcplots)
library(magrittr)
library(multipanelfigure)
library(data.table)
library(tidyverse)
# tidylog must be attached AFTER dplyr/tidyr (attached by tidyverse above) so
# that its logging wrappers mask the dplyr verbs instead of being masked.
library(tidylog)

# Project directory containing the STAN/ folder. setwd("") is an error, so the
# directory is only changed once a path has actually been filled in.
project_dir <- ""
if (nzchar(project_dir)) {
  setwd(project_dir)
}



# Data passed to the Stan model: `n` strata, each other element holds one
# value per stratum in the same order. age_g2, age_g3 and sex_g contain 0/1
# values.
data <- list(
  n = 6,
  P = c(691165, 506885, 547164, 687621, 532921, 580489),
  pyr = c(1114, 4962, 1809, 760, 2292, 732),
  age_g2 = c(0, 1, 0, 0, 1, 0),
  age_g3 = c(0, 0, 1, 0, 0, 1),
  sex_g = c(1, 1, 1, 0, 0, 0),
  events_cohort = c(25, 156, 51, 13, 60, 14),
  n_known = c(3118, 15865, 6621, 2255, 7837, 2377),
  events_extra = c(87, 86, 33, 18, 27, 11)
)

# Every per-stratum vector must have exactly `n` entries; a mismatch would be
# silently recycled by vectorised R code (and presumably rejected by Stan).
stopifnot(all(lengths(data[names(data) != "n"]) == data$n))

# str() shows every element; head() on a list prints only the first six.
str(data)

# Compile the Stan model (path is relative to the working directory).
example7 <- cmdstan_model("STAN/example7.stan")

# One list of starting values per chain. beta has 3 and gamma 4 entries --
# presumably matching the parameter dimensions declared in example7.stan
# (TODO confirm against the model file).
initial_values <- list(
  list(beta0 = -4.1, beta = rep(0, 3), gamma = rep(0, 4)),
  list(beta0 = -4.3, beta = rep(1, 3), gamma = rep(-0.5, 4)),
  list(beta0 = -3.9, beta = rep(-1, 3), gamma = rep(0.5, 4))
)

# cmdstanr requires exactly one init list per chain, so derive the chain count
# from the inits rather than hard-coding it in two places.
n_chains <- length(initial_values)

example7_fit <- example7$sample(
  data = data,
  seed = 123,
  iter_warmup = 5000,
  iter_sampling = 5000,
  chains = n_chains,
  parallel_chains = n_chains,
  save_warmup = TRUE,
  thin = 1,
  max_treedepth = 15,
  init = initial_values
)

# Posterior summary of all model variables, converted in place to a data.table.
example7_table <- example7_fit$summary() %>% setDT()
example7_table

# Convergence check: R-hat should be close to 1 for every variable. Reuse the
# table computed above instead of summarising all draws again, and ignore NA
# R-hats (e.g. for quantities constant across draws) so they do not turn the
# whole range into NA.
range(example7_table$rhat, na.rm = TRUE)



# Median and 5%/95% quantiles of the indexed variables prefix[1], ...,
# prefix[n] from a posterior summary table, multiplied by `scale` and rounded
# to `digits` (negative digits round to tens, hundreds, ...). Rows come back
# in index order and are labelled with the variable names.
extract_interval <- function(summary_df, prefix, n, scale = 1, digits = 2) {
  summary_df <- as.data.frame(summary_df)
  wanted <- paste0(prefix, "[", seq_len(n), "]")
  idx <- match(wanted, summary_df$variable)
  if (anyNA(idx)) {
    stop("Variables not found in summary: ",
         paste(wanted[is.na(idx)], collapse = ", "), call. = FALSE)
  }
  out <- summary_df[idx, c("median", "q5", "q95"), drop = FALSE] * scale
  out <- round(out, digits)
  rownames(out) <- wanted
  out
}

# Prevalences as percentages (2 d.p.) and N rounded to the nearest hundred,
# all taken from the summary table already computed above.
extract_interval(example7_table, "prev", data$n, scale = 100)
extract_interval(example7_table, "prev_known", data$n, scale = 100)
extract_interval(example7_table, "N", data$n, digits = -2)


# Show only the summary rows for the ResDevTotal* quantities.
filter(
  example7_table,
  variable %in% c("ResDevTotal_cohort", "ResDevTotal_extra", "ResDevTotal")
)





# DIC ----
# Post-warmup draws as a data.frame. data.frame() applies make.names() to the
# column names, so e.g. "mean_cohort[1]" becomes "mean_cohort.1.".
posterior_sample <- data.frame(
  example7_fit$draws(format = "matrix", inc_warmup = FALSE)
)
N_size <- data$n
stratum <- seq_len(N_size)

# Guard against silent recycling in the vectorised dpois() calls below.
stopifnot(
  length(data$events_cohort) == N_size,
  length(data$events_extra) == N_size
)

# Plug-in point estimate: posterior mean of each stratum's Poisson mean.
mean_cohort_hat <- colMeans(
  posterior_sample[, paste0("mean_cohort.", stratum, "."), drop = FALSE]
)
mean_extra_hat <- colMeans(
  posterior_sample[, paste0("mean_extra.", stratum, "."), drop = FALSE]
)

# Per-stratum log-likelihood of both data sources at the plug-in estimate.
log_lik_theta_hat <- unname(
  dpois(data$events_cohort, mean_cohort_hat, log = TRUE) +
    dpois(data$events_extra, mean_extra_hat, log = TRUE)
)

total_log_lik_theta_hat <- sum(log_lik_theta_hat)
# Posterior mean of the model's `log_lik` quantity. Assumes `log_lik` is the
# scalar total over both data sources, matching the sum above (TODO confirm
# against example7.stan).
mean_log_lik_theta <- mean(posterior_sample[, "log_lik"])

# With deviance D = -2 * log-lik: pD = Dbar - D(theta_hat).
pD <- 2 * (total_log_lik_theta_hat - mean_log_lik_theta)
if (pD < 0) {
  warning("Negative pD (", signif(pD, 4), "): the plug-in estimate fits worse ",
          "than the average draw; DIC is unreliable here.", call. = FALSE)
}
# DIC = D(theta_hat) + 2 * pD
DIC <- (-2 * total_log_lik_theta_hat) + (2 * pD)
# Dbar + pD. NOTE(review): since Dbar = D(theta_hat) + pD by construction, this
# is algebraically identical to DIC above (differences are only rounding).
DIC3 <- (-2 * mean_log_lik_theta) + pD


